# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
import copy
from typing import List, Optional

import json
import logging
import os
import torch
import yaml

import torch.optim as optim
import torch.distributed as dist

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    CPUOffload, MixedPrecision,
                                    sharded_grad_scaler, ShardingStrategy)
try:
    import deepspeed
    from deepspeed.runtime.zero.stage_1_and_2 import (
        estimate_zero2_model_states_mem_needs_all_live)
    from deepspeed.runtime.zero.stage3 import (
        estimate_zero3_model_states_mem_needs_all_live)
    from deepspeed.utils.zero_to_fp32 import (
        convert_zero_checkpoint_to_fp32_state_dict)
except ImportError:
    pass


from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str,
                                tensor_to_scalar)
from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
                                    apply_fsdp_checkpointing,
                                    wenet_fsdp_wrap_policy)
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE
from wenet.utils.init_dataset import init_dataset


def add_model_args(parser):
    """Register model / checkpoint related command line flags on ``parser``.

    Comma separated list flags (``--enc_init_mods``, ``--freeze_modules``)
    are split on ``,`` and empty items are dropped, so a trailing or doubled
    comma never produces an empty module name (an empty prefix would match
    every module).

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, for chaining.
    """

    def _split_csv(text):
        # Filter on each item (not on the whole string) so "a,,b," -> ['a', 'b'].
        return [mod for mod in text.split(",") if mod]

    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument("--enc_init",
                        default=None,
                        type=str,
                        help="Pre-trained model to initialize encoder")
    parser.add_argument(
        '--enc_init_mods',
        default="encoder.",
        type=_split_csv,
        help="List of encoder modules \
                        to initialize ,separated by a comma")
    parser.add_argument(
        '--freeze_modules',
        default="",
        type=_split_csv,
        help='freeze module names',
    )
    return parser


def add_trace_args(parser):
    """Register the model tracing / inspection switches on ``parser``.

    Both flags are plain ``store_true`` switches defaulting to False.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, for chaining.
    """
    switches = (
        ('--jit', 'if use jit to trace model while training stage'),
        ('--print_model', 'print model'),
    )
    for flag, help_text in switches:
        parser.add_argument(flag,
                            action='store_true',
                            default=False,
                            help=help_text)
    return parser


def add_dataset_args(parser):
    """Register dataset and data loader flags on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, for chaining.
    """
    add = parser.add_argument
    # NOTE: --data_type is intentionally not restricted with `choices`
    #   (a former restriction to ['raw', 'shard'] was left commented out).
    add('--data_type', default='raw', help='train and cv data type')
    add('--train_data', required=True, help='train data file')
    add('--cv_data', required=True, help='cv data file')
    add('--num_workers',
        type=int,
        default=0,
        help='num of subprocess workers for reading')
    add('--pin_memory',
        action='store_true',
        default=False,
        help='Use pinned memory buffers used for reading')
    add('--prefetch', type=int, default=100, help='prefetch number')
    return parser


def add_lora_args(parser):
    """Configure parameters for LoRA fine-tuning.

    Set use_lora and only_optimize_lora to true to enable LoRA functionality.
    LoRA will be injected to model through (lora_modules, lora_attn_attr,
    lora_list). LoRA weights will be merged after calling model.eval()
    (or model.train(mode=False)). LoRA weights need to be loaded after
    fine-tuning with DeepSpeed.

    Boolean flags take an explicit value (e.g. ``--use_lora true``) that is
    parsed case-insensitively; ``type=bool`` would have turned any non-empty
    string, including "False", into True. Comma separated list flags drop
    empty items.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, for chaining.
    """
    import argparse

    def _str2bool(value):
        # bool("False") is True, so interpret the text explicitly.
        # '' stays False, as it was with type=bool.
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    def _split_csv(text):
        # Filter on each item (not on the whole string) so "a,,b," -> ['a', 'b'].
        return [mod for mod in text.split(",") if mod]

    parser.add_argument("--use_lora",
                        default=False,
                        type=_str2bool,
                        help="whether use the lora finetune.")
    parser.add_argument("--only_optimize_lora",
                        default=False,
                        type=_str2bool,
                        help="freeze all other parameters and only optimize \
                        LoRA-related parameters.")
    parser.add_argument(
        '--lora_modules',
        default="encoder.encoders",
        type=_split_csv,
        help='modules names needs inject lora',
    )
    parser.add_argument(
        "--lora_attn_attr",
        default="self_attn,src_attn",
        type=_split_csv,
        help="lora_attn_attr.")
    parser.add_argument(
        "--lora_list",
        default="linear_out,linear_q,linear_k,linear_v",
        type=_split_csv,
        help="lora module list.")
    parser.add_argument("--lora_rank",
                        default=8,
                        type=int,
                        help="lora rank num.")
    parser.add_argument("--lora_alpha",
                        default=8,
                        type=int,
                        help="lora scale param, scale=lora_alpha/lora_rank.")
    parser.add_argument("--lora_dropout",
                        default=0,
                        type=float,
                        help="lora dropout param.")
    parser.add_argument("--lora_ckpt_path",
                        default=None,
                        type=str,
                        help="lora checkpoint path.")
    parser.add_argument("--lora_reinit",
                        default=False,
                        type=_str2bool,
                        help="whether use the lora init, default is zero init.")
    parser.add_argument('--lora_init_yaml',
                        default="wenet/finetune/lora/config.yaml",
                        type=str,
                        help='Path to the configuration YAML file')
    return parser


def add_ddp_args(parser):
    """Register native PyTorch DDP flags on ``parser``.

    ``--ddp.dist_backend`` is stored as ``args.dist_backend``.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, for chaining.
    """
    backends = ['nccl', 'gloo', "hccl"]
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        choices=backends,
                        default='nccl',
                        help='distributed backend')
    for flag, help_text in (
        ('--use_amp', 'Use automatic mixed precision training'),
        ('--fp16_grad_sync', 'Use fp16 gradient sync for ddp'),
    ):
        parser.add_argument(flag,
                            action='store_true',
                            default=False,
                            help=help_text)
    return parser


def add_deepspeed_args(parser):
    """Register DeepSpeed related flags on ``parser``.

    Besides the wenet specific options, DeepSpeed's own config arguments
    ('--deepspeed', '--deepspeed_config', ...) are attached when the
    ``deepspeed`` package is usable. Otherwise the error is printed and the
    parser is returned without them.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The parser, possibly the one returned by
        ``deepspeed.add_config_arguments``.
    """
    timeout_help = ('timeout (in seconds) of wenet_join. '
                    '30s for aishell & 300s for wenetspeech')
    parser.add_argument('--timeout', type=int, default=30, help=timeout_help)
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        choices=['model_only', 'model+optimizer'],
                        default='model_only',
                        help='save model/optimizer states')
    # DeepSpeed automatically adds '--deepspeed' and '--deepspeed_config'.
    # Best effort: if deepspeed failed to import, just report why.
    try:
        parser = deepspeed.add_config_arguments(parser)
    except Exception as err:
        print(err)
    return parser


def add_fsdp_args(parser):
    """Register FSDP related flags on ``parser``.

    Boolean flags take an explicit value (e.g.
    ``--fsdp_sync_module_states false``) that is parsed case-insensitively.
    With ``type=bool`` any non-empty string, including "False", became True,
    so these options could not actually be turned off from the command line.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, for chaining.
    """
    import argparse

    def _str2bool(value):
        # bool("False") is True, so interpret the text explicitly.
        # '' stays False, as it was with type=bool.
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser.add_argument(
        '--dtype',
        default='fp32',
        choices=['fp32', 'fp16', 'bf16'],
        help='when amp is used, dtype is automatically set to fp16.\
        this arg has no effect when deepspeed is enabled.')
    parser.add_argument(
        '--fsdp_cpu_offload',
        default=False,
        type=_str2bool,
        help='whether to offload parameters to CPU',
    )
    parser.add_argument(
        '--fsdp_sync_module_states',
        type=_str2bool,
        default=True,
        help='\
        each FSDP module will broadcast module parameters and buffers from \
        rank 0 to ensure that they are replicated across ranks',
    )
    parser.add_argument(
        '--fsdp_sharding_strategy',
        default='zero2',
        # TODO(Mddct): pipeline and model parallel (3-D parallelism)
        choices=['no_shard', 'model', 'zero2', 'zero3'],
        help='Sharding strategy for FSDP. Choose from the following options:\n'
        '  - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
        '  - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
        '  - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
        '  - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
        'For more information, refer to the FSDP API documentation.')
    return parser


def init_distributed(args):
    """Initialize the process group for the selected training engine.

    Rank information is read from the launcher's environment variables
    (``WORLD_SIZE``, ``LOCAL_RANK``, ``RANK``), defaulting to a single
    process. For ``torch_ddp`` / ``torch_fsdp`` the local device is selected
    (cuda or npu) before ``dist.init_process_group``. For ``deepspeed`` the
    DeepSpeed initializer is used. Unsupported devices or engines are only
    logged as errors.

    Args:
        args: parsed arguments; reads ``train_engine``, ``device`` and
            ``dist_backend``.

    Returns:
        ``(world_size, local_rank, rank)`` as ints.
    """
    env = os.environ
    world_size = int(env.get('WORLD_SIZE', 1))
    local_rank = int(env.get('LOCAL_RANK', 0))
    rank = int(env.get('RANK', 0))
    logging.info(f'training on multiple gpus, this gpu {local_rank}'
                 f', rank {rank}, world_size {world_size}')

    engine = args.train_engine
    if engine == "deepspeed":
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    elif engine in ("torch_ddp", "torch_fsdp"):
        device_name = args.device
        if "cuda" in device_name:
            torch.cuda.set_device(local_rank)
        elif "npu" in device_name and TORCH_NPU_AVAILABLE:
            torch.npu.set_device(local_rank)
        else:
            logging.error(f"not supported device: {device_name}")
        dist.init_process_group(args.dist_backend)
    else:
        logging.error(f"not supported engine: {engine}")
    return world_size, local_rank, rank


def check_modify_and_save_config(args, configs, symbol_table):
    """Finalize ``configs`` from the CLI ``args`` and save it as train.yaml.

    Steps:
      * sets ``configs['dtype']`` for the chosen engine (forced to fp16 when
        ``--use_amp`` is on for torch engines; taken from ds_config for
        DeepSpeed),
      * for DeepSpeed, asserts ds_config agrees with the yaml config
        (micro batch 1, accum_grad, grad_clip, log_interval),
      * records LoRA options under ``configs['lora_conf']`` when enabled,
      * infers ``input_dim`` for ``asr_model`` from the feature config,
      * resolves the blank id via ``get_blank_id`` and sets ``output_dim``,
      * copies engine / amp / model_dir / save_states into ``configs``,
      * on rank 0, dumps the result to ``<model_dir>/train.yaml``.

    Args:
        args: parsed command line arguments.
        configs: training config dict; updated in place.
        symbol_table: forwarded to ``get_blank_id``.

    Returns:
        The updated config dict (as returned by ``get_blank_id``).

    Raises:
        AssertionError: if the DeepSpeed config disagrees with ``configs``.
    """
    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
        if args.use_amp:
            configs["dtype"] = "fp16"
            args.dtype = 'fp16'
        else:
            configs["dtype"] = args.dtype
    elif args.train_engine == "deepspeed":
        # NOTE(xcsong): DeepSpeed does not support uneven data. When using custom
        #   dataset, we need to manually ensure that the data is evenly distributed
        #   across all processes. we impl `train_utils.py::wenet_join` for this func
        #   ref: https://github.com/microsoft/DeepSpeed/issues/2223
        #
        # NOTE(xsong):  We also need to keep:
        #       1. `train_micro_batch_size_per_gpu == 1`
        #       2. `accum_grad (in train_conformer.yaml)
        #               == gradient_accumulation_steps (in ds_config.json)`
        #       3. `grad_clip (in train_conformer.yaml)
        #               == gradient_clipping (in ds_config.json)`
        #   The reason for such consistency checking lies in that deepspeed's native
        #   dataloader uses PyTorch's torch.utils.data.DistributedSampler which does
        #   not support IterableDataset, IterableDataset is extremely useful in large
        #   scale training because it lets you stream the data without having to
        #   download the complete dataset.
        #       ref: https://github.com/microsoft/DeepSpeed/issues/1371
        #           https://github.com/microsoft/DeepSpeed/issues/285
        #   To make deepspeed training compatible with IterableDataset, we have to
        #   use custom dataloader instead of deepspeed's native loader and thus we
        #   should configure batchsize in train_conformer.yaml instead of
        #   ds_config.json. On the contrary, gradient accumulation / clipping should be
        #   configured in ds_config.json since they will be handled by ds automatically.
        #       ref: https://github.com/microsoft/DeepSpeed/issues/62
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
            configs["dtype"] = "fp16"
        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
            configs["dtype"] = "bf16"
        else:
            configs["dtype"] = "fp32"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        assert ds_configs["gradient_accumulation_steps"] == configs[
            'accum_grad']
        assert ds_configs["gradient_clipping"] == configs['grad_clip']
        assert ds_configs["steps_per_print"] == configs['log_interval']

    if args.use_lora:
        configs['lora_conf'] = {
            'lora_modules': args.lora_modules,
            'lora_attn_attr': args.lora_attn_attr,
            'lora_list': args.lora_list,
            'lora_rank': args.lora_rank,
            'lora_alpha': args.lora_alpha,
            'lora_dropout': args.lora_dropout,
        }

    if configs["model"] == 'asr_model':
        if 'input_dim' not in configs:
            # Feature dimension comes from whichever feature extractor is
            # configured: fbank, then whisper-style log-mel, then mfcc.
            dataset_conf = configs['dataset_conf']
            if 'fbank_conf' in dataset_conf:
                input_dim = dataset_conf['fbank_conf']['num_mel_bins']
            elif 'log_mel_spectrogram_conf' in dataset_conf:
                input_dim = dataset_conf['log_mel_spectrogram_conf'][
                    'num_mel_bins']
            else:
                input_dim = dataset_conf['mfcc_conf']['num_mel_bins']
        else:
            input_dim = configs['input_dim']

        configs['input_dim'] = input_dim

    configs, _ = get_blank_id(configs, symbol_table)
    configs['output_dim'] = configs['vocab_size']

    configs['train_engine'] = args.train_engine
    configs['use_amp'] = args.use_amp
    configs['model_dir'] = args.model_dir
    configs['save_states'] = args.save_states

    # Save configs to model_dir/train.yaml for inference and export
    if int(os.environ.get('RANK', 0)) == 0:
        saved_config_path = os.path.join(args.model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    if configs.get("model_conf", {}).get("apply_non_blank_embedding", False):
        # The two literals are concatenated; keep the separating space.
        logging.warning('Had better load a well trained model '
                        'if apply_non_blank_embedding is true !!!')

    return configs


def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
    """Build the train / cv datasets and their DataLoaders.

    Side effects on ``configs``: sets ``configs['vocab_size']`` from the
    tokenizer and, in step mode (``save_interval`` present), sets
    ``configs['dataset_conf']['cycle']`` to ``max_epoch`` (default 100).

    Worker-only DataLoader options (``persistent_workers``,
    ``prefetch_factor``) are passed only when ``args.num_workers > 0``;
    ``DataLoader`` rejects them for single-process loading, which is the
    parser's default (``--num_workers 0``).

    Args:
        args: parsed arguments; reads ``data_type``, ``train_data``,
            ``cv_data``, ``num_workers``, ``pin_memory`` and ``prefetch``.
        configs: training config dict.
        tokenizer: tokenizer exposing ``vocab_size()``.
        seed: seed for the DataLoader ``torch.Generator``.

    Returns:
        ``(train_dataset, cv_dataset, train_data_loader, cv_data_loader)``.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)

    # if save_interval in configs, steps mode else epoch mode
    if "save_interval" in configs:
        configs['dataset_conf']['cycle'] = configs.get('max_epoch', 100)
    conf = configs['dataset_conf']
    dataset_type = configs.get('dataset', 'asr')
    configs['vocab_size'] = tokenizer.vocab_size()
    train_dataset = init_dataset(dataset_type,
                                 args.data_type,
                                 args.train_data,
                                 tokenizer,
                                 conf,
                                 True,
                                 split='train')
    init_infos = configs.get("init_infos", {})
    tag = init_infos.get("tag", "init")
    # Epoch restored from the checkpoint info, offset by one when the tag is
    # an "epoch_*" checkpoint. NOTE(review): presumably keeps the data
    # shuffling epoch in step with the resumed training loop; confirm.
    train_dataset.set_epoch(
        init_infos.get('epoch', 0) + int("epoch_" in tag) - 1)
    cv_conf = copy.deepcopy(conf)
    cv_conf['split_num'] = 1
    cv_dataset = init_dataset(dataset_type,
                              args.data_type,
                              args.cv_data,
                              tokenizer,
                              cv_conf,
                              partition=False,
                              split='cv')

    loader_kwargs = dict(batch_size=None,
                         pin_memory=args.pin_memory,
                         num_workers=args.num_workers,
                         generator=generator)
    if args.num_workers > 0:
        # NOTE(xcsong): Why we prefer persistent_workers=True ?
        #   https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
        loader_kwargs['persistent_workers'] = True
        loader_kwargs['prefetch_factor'] = args.prefetch
    train_data_loader = DataLoader(train_dataset, **loader_kwargs)
    cv_data_loader = DataLoader(cv_dataset, **loader_kwargs)
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader


def wrap_cuda_model(args, model, configs=None):
    """Wrap ``model`` for the configured distributed training engine.

    * ``torch_ddp``: moves the model to ``args.device`` and wraps it in
      ``DistributedDataParallel``.
    * ``deepspeed``: on rank 0 logs zero2/zero3 memory estimates; the model
      is returned unwrapped because DeepSpeed is initialized later.
    * ``torch_fsdp``: wraps the model in FSDP using ``configs['dtype']`` for
      mixed precision and ``args.fsdp_*`` for sharding / offload, then
      applies activation checkpointing.

    With ``args.fp16_grad_sync`` the fp16 compression comm hook is
    registered for the torch engines.

    Args:
        args: parsed arguments (``train_engine``, ``device``,
            ``fp16_grad_sync`` and, for FSDP, ``fsdp_sharding_strategy``,
            ``fsdp_cpu_offload``, ``fsdp_sync_module_states``).
        model: the ``torch.nn.Module`` to wrap.
        configs: training config dict; required for ``torch_fsdp``.

    Returns:
        ``(model, device)``.

    Raises:
        ValueError: for an unknown ``train_engine``, or FSDP on a device
            other than cuda/npu. Previously these cases failed later with an
            UnboundLocalError.
    """
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if args.train_engine == "torch_ddp":  # native pytorch ddp
        device = torch.device(args.device)
        model.to(device)
        # NOTE: find_unused_parameters stays enabled unconditionally, even
        #   when the encoder uses gradient checkpointing.
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
    elif args.train_engine == "deepspeed":  # deepspeed
        # NOTE(xcsong): look in detail how the memory estimator API works:
        #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
        if int(os.environ.get('RANK', 0)) == 0:
            logging.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
            logging.info("Estimating model states memory needs (zero3)...")
            estimate_zero3_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
        device = torch.device(args.device)  # Init device later
        # Init DeepSpeed later
    elif args.train_engine == 'torch_fsdp':
        assert configs is not None
        # Resolve the device first: FSDP needs a concrete device_id.
        if "cuda" in args.device:
            device_id = torch.cuda.current_device()
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            device_id = torch.npu.current_device()
        else:
            logging.error("not supported device: {}".format(args.device))
            raise ValueError("not supported device: {}".format(args.device))
        mixed_precision_dtype = {
            'fp32': torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[configs['dtype']]

        sharding_strategy = {
            'model': ShardingStrategy.SHARD_GRAD_OP,
            'zero2': ShardingStrategy.SHARD_GRAD_OP,
            'zero3': ShardingStrategy.FULL_SHARD,
            'no_shard': ShardingStrategy.NO_SHARD,
        }[args.fsdp_sharding_strategy]
        wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy)
        layer_types = check_gradient_checkpoint(model)
        model = FSDP(
            model,
            auto_wrap_policy=wrap_policy,
            cpu_offload=CPUOffload(offload_params=True)
            if args.fsdp_cpu_offload is True else None,
            mixed_precision=MixedPrecision(
                param_dtype=mixed_precision_dtype,
                reduce_dtype=mixed_precision_dtype,
                buffer_dtype=mixed_precision_dtype,
            ),
            sharding_strategy=sharding_strategy,
            limit_all_gathers=True,
            use_orig_params=True,
            sync_module_states=args.fsdp_sync_module_states,
            # init_distributed is called (torch.cuda.set_device),
            # we should set device_id, see FSDP api
            device_id=device_id)
        apply_fsdp_checkpointing(model, layer_types)
        device = torch.device(args.device)
    else:
        logging.error("not supported engine: {}".format(args.train_engine))
        raise ValueError("not supported engine: {}".format(
            args.train_engine))
    if args.train_engine in ["torch_fsdp", "torch_ddp"]:
        if args.fp16_grad_sync:
            from torch.distributed.algorithms.ddp_comm_hooks import (
                default as comm_hooks, )
            model.register_comm_hook(state=None,
                                     hook=comm_hooks.fp16_compress_hook)

    return model, device


def init_optimizer_and_scheduler(args, configs, model):
    """Create the optimizer and LR scheduler; initialize DeepSpeed if used.

    When ``optim_conf['lr']`` is a list, parameters are split into groups:
    one per module named in ``optim_conf['modules']`` (same order as the
    learning rates) plus a final group holding every remaining parameter
    with ``lr[-1]``. This mode requires the ``warmuplr`` scheduler.

    For DeepSpeed, an optimizer (and optionally scheduler) declared in the
    ds_config takes precedence over the one built here. The engine returned
    by ``deepspeed.initialize`` replaces ``model``.

    Args:
        args: parsed arguments (``train_engine``, ``deepspeed_config``).
        configs: training config dict (``optim``, ``optim_conf``,
            ``scheduler``, ``scheduler_conf``, optional ``init_infos``).
        model: the (possibly wrapped) model.

    Returns:
        ``(model, optimizer, scheduler)``; the scheduler step is restored
        from ``configs['init_infos']['step']`` (default -1).

    Raises:
        ValueError: for an unknown optimizer or scheduler name.
    """
    lr = configs['optim_conf'].get('lr')
    per_module_lr = isinstance(lr, list)
    param_groups = []
    if per_module_lr:
        assert configs['scheduler'] == 'warmuplr'
        module_names = configs['optim_conf']['modules']
        assert isinstance(module_names, list)
        assert len(module_names) + 1 == len(lr)
        claimed_ids = set()
        for module_name, module_lr in zip(module_names, lr):
            module_params = list(
                get_nested_attribute(model, module_name).parameters())
            claimed_ids.update(id(p) for p in module_params)
            param_groups.append({'params': module_params, 'lr': module_lr})
        # everything not claimed by a named module uses the last lr
        leftover = [
            p for p in model.parameters() if id(p) not in claimed_ids
        ]
        param_groups.append({'params': leftover, 'lr': lr[-1]})

    params = param_groups if param_groups else model.parameters()
    optim_conf = copy.deepcopy(configs['optim_conf'])
    optim_conf.pop('modules', None)
    if per_module_lr:
        optim_conf['lr'] = lr[-1]

    optimizer_classes = {'adam': optim.Adam, 'adamw': optim.AdamW}
    optim_name = configs['optim']
    if optim_name not in optimizer_classes:
        raise ValueError("unknown optimizer: " + configs['optim'])
    optimizer = optimizer_classes[optim_name](params, **optim_conf)

    scheduler_name = configs['scheduler']
    if scheduler_name == 'warmuplr':
        scheduler_type = WarmupLR
    elif scheduler_name == 'NoamHoldAnnealing':
        scheduler_type = NoamHoldAnnealing
    else:
        raise ValueError("unknown scheduler: " + configs['scheduler'])
    scheduler = scheduler_type(optimizer, **configs['scheduler_conf'])

    # NOTE(xcsong): Custom optimizer might yield poor performance when
    #   zero-offload is enabled, if you do want to offload optimizer to CPU,
    #   please set optimizer in ds_config.json, see:
    #   (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters)
    if args.train_engine == "deepspeed":
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "optimizer" in ds_configs:
            # NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
            # extremely useful when enable cpu_offload, DeepspeedCpuAdam
            # could be 4~5x faster than torch native adam
            optimizer = None
            if "scheduler" in ds_configs:
                scheduler = None
            else:
                # DeepSpeed calls this factory with its own optimizer.
                def build_scheduler(opt):
                    return scheduler_type(opt, **configs['scheduler_conf'])

                scheduler = build_scheduler

        model, optimizer, _, scheduler = deepspeed.initialize(
            args=args,
            model=model,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            model_parameters=model.parameters())

    scheduler.set_step(configs.get("init_infos", {}).get("step", -1))
    return model, optimizer, scheduler


def trace_and_print_model(args, model):
    """On rank 0, optionally script-export and/or print the model.

    With ``args.jit`` the model is compiled by ``torch.jit.script`` and saved
    to ``<model_dir>/init.zip``. With ``args.print_model`` the module and its
    parameter count are printed to stdout. Other ranks do nothing.
    """
    if int(os.environ.get('RANK', 0)) != 0:
        return
    # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
    # the code to satisfy the script export requirements
    if args.jit:
        torch.jit.script(model).save(os.path.join(args.model_dir, 'init.zip'))
    if args.print_model:
        print(model)
        total = 0
        for param in model.parameters():
            total += param.numel()
        print('the number of model params: {:,d}'.format(total))


def init_summarywriter(args):
    """Return a TensorBoard ``SummaryWriter`` on rank 0, ``None`` elsewhere.

    On rank 0 this also creates ``args.model_dir``. The log directory is
    ``<tensorboard_dir>/<basename(model_dir)>``.
    """
    if int(os.environ.get('RANK', 0)) != 0:
        return None
    os.makedirs(args.model_dir, exist_ok=True)
    log_dir = os.path.join(args.tensorboard_dir,
                           os.path.basename(args.model_dir))
    return SummaryWriter(log_dir)


def init_scaler(args):
    """Create the gradient scaler for fp16 training, or ``None``.

    * ``--use_amp``: a CUDA / NPU ``GradScaler``. For ``torch_fsdp`` on CUDA
      the shard-aware ``ShardedGradScaler`` is used instead, the same scaler
      this function returns for FSDP fp16 without amp. With sharded
      gradients it combines the inf/nan check across ranks, so every rank
      makes the same skip-step / scale-update decision.
    * ``torch_fsdp`` with ``--dtype fp16``: ``ShardedGradScaler``.
    * bf16 / fp32: no scaler.

    Args:
        args: parsed arguments (``use_amp``, ``device``, ``train_engine``,
            ``dtype``).

    Returns:
        A grad scaler instance or ``None``.
    """
    scaler = None
    is_fsdp = args.train_engine == 'torch_fsdp'
    if args.use_amp:
        if "cuda" in args.device:
            if is_fsdp:
                scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
            else:
                scaler = torch.cuda.amp.GradScaler()
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            scaler = torch.npu.amp.GradScaler()
        else:
            logging.error("not supported device: {}".format(args.device))
    elif is_fsdp:
        # why bf16 don't need scaler:
        # https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596
        if args.dtype in ['fp16']:
            scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
    return scaler


def save_model(model, info_dict):
    """Save a checkpoint for the active engine and an ``<tag>.yaml`` sidecar.

    * deepspeed: every rank calls ``model.save_checkpoint``. With
      ``save_states == "model_only"``, rank 0 then converts the ZeRO
      checkpoint to a single fp32 ``<tag>.pt`` and deletes the
      ``<model_dir>/<tag>`` checkpoint directory.
    * torch_fsdp: delegates to ``fsdp_save_model``.
    * torch_ddp: rank 0 calls ``save_checkpoint``.
    Rank 0 then dumps ``info_dict`` to ``<model_dir>/<tag>.yaml``.

    Args:
        model: the trained (possibly wrapped) model / engine.
        info_dict: must contain ``tag``, ``model_dir``, ``train_engine``
            and, for deepspeed, ``save_states``.
    """
    import shutil

    rank = int(os.environ.get('RANK', 0))
    tag = info_dict["tag"]
    model_dir = info_dict["model_dir"]
    save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
    # save ckpt
    if info_dict["train_engine"] == "deepspeed":
        # NOTE(xcsong): All ranks should call this API, but only rank 0
        #   save the general model params. see:
        #   https://github.com/microsoft/DeepSpeed/issues/2993
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
                                  tag=tag,
                                  client_state=info_dict)
            if info_dict["save_states"] == "model_only" and rank == 0:
                convert_zero_checkpoint_to_fp32_state_dict(model_dir,
                                                           save_model_path,
                                                           tag=tag)
                # Remove the sharded checkpoint dir without going through a
                # shell: `rm -rf dir/tag` built from a string would split
                # paths containing spaces and delete the wrong directories.
                shutil.rmtree(os.path.join(model_dir, tag),
                              ignore_errors=True)

    elif info_dict['train_engine'] == "torch_fsdp":
        fsdp_save_model(model, save_model_path, info_dict)
    elif rank == 0:
        # NOTE(xcsong): For torch_ddp, only rank-0 should call this.
        save_checkpoint(model, save_model_path, info_dict)
    # save yaml
    if rank == 0:
        with open(os.path.join(model_dir, '{}.yaml'.format(tag)),
                  'w') as fout:
            fout.write(yaml.dump(info_dict))


def wenet_join(group_join, info_dict):
    """Detect uneven data across ranks with a monitored barrier.

    Skipped (returns False) on the first batch and for ``torch_ddp``.
    Otherwise waits on ``group_join`` with that group's timeout. A
    ``RuntimeError`` from the barrier means some rank ran out of data; it is
    logged and True is returned so this worker can break out of its loop.

    Args:
        group_join: a process group dedicated to this join check.
        info_dict: must contain ``batch_idx``; reads ``train_engine``
            (default ``"torch_ddp"``).

    Returns:
        True if the worker should stop, False otherwise.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    engine = info_dict.get('train_engine', "torch_ddp")

    # NOTE(xcsong): skip first batch because its processing time includes
    #   dataloader initialization time, which may exceed 30 seconds
    should_skip = info_dict["batch_idx"] == 0 or engine == "torch_ddp"
    if should_skip:
        return False

    # NOTE(xcsong): Why we need a new group?
    #   Because Deepspeed has its own group where all the relevant communication
    #   operations are executed. If we add a communication operation that is not
    #   managed by Deepspeed in this group, it's highly likely to cause
    #   communication chaos, resulting in hard-to-troubleshoot hangs.
    try:
        dist.monitored_barrier(group=group_join,
                               timeout=group_join.options._timeout)
    except RuntimeError as err:
        logging.info(
            f"Detected uneven workload distribution: {err}\n"
            "Break current worker to manually join all workers, "
            f"world_size {world_size}, current rank {rank}, "
            f"current local_rank {local_rank}\n")
        return True
    return False


def batch_forward(model, batch, scaler, info_dict, device):
    """Run ``model(batch, device)`` under the engine's autocast context.

    Autocast settings per engine:
      * deepspeed: enabled when ``info_dict['dtype']`` is fp16/bf16, with
        that dtype and ``cache_enabled=False``;
      * torch_ddp: enabled iff a grad ``scaler`` is given (amp);
      * torch_fsdp: enabled with fp16/bf16, otherwise a ``nullcontext``.
    NPU autocast is used when ``device`` is an npu device and torch_npu is
    available. Only the context for the active engine is constructed.

    Args:
        model: callable taking ``(batch, device)`` and returning a loss dict.
        batch: the batch to forward.
        scaler: grad scaler or ``None``.
        info_dict: reads ``train_engine`` (default ``"torch_ddp"``) and
            ``dtype`` (default ``"fp32"``); receives ``loss_dict``.
        device: the ``torch.device`` the model runs on.

    Returns:
        ``info_dict`` with ``info_dict['loss_dict']`` set.

    Raises:
        KeyError: for an unknown ``train_engine``.
    """
    train_engine = info_dict.get('train_engine', "torch_ddp")

    # fp32 (or anything else) -> None, meaning "no reduced precision"
    dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }.get(info_dict.get("dtype", "fp32"))

    # autocast context
    # The more details about amp can be found in
    # https://pytorch.org/docs/stable/notes/amp_examples.html
    amp_autocast = torch.cuda.amp.autocast
    if "npu" in str(device) and TORCH_NPU_AVAILABLE:
        amp_autocast = torch.npu.amp.autocast
    if train_engine == "deepspeed":
        autocast = amp_autocast(enabled=dtype is not None,
                                dtype=dtype,
                                cache_enabled=False)
    elif train_engine == "torch_ddp":
        autocast = amp_autocast(enabled=scaler is not None)
    elif train_engine == "torch_fsdp":
        autocast = (amp_autocast(enabled=True, dtype=dtype)
                    if dtype is not None else nullcontext())
    else:
        raise KeyError(train_engine)

    with autocast:
        loss_dict = model(batch, device)

    info_dict['loss_dict'] = loss_dict
    return info_dict


def batch_backward(model, scaler, info_dict):
    """Back-propagate ``info_dict['loss_dict']['loss']``.

    DeepSpeed: ``model.backward(loss)`` handles accumulation scaling and
    returns the scaled loss. Torch engines: the loss is divided by
    ``accum_grad`` and back-propagated, through ``scaler.scale`` when a
    scaler is present (fp16). Afterwards the stored loss is replaced by the
    scaled one and every non-None entry of the loss dict is converted to a
    Python scalar.

    Args:
        model: the model / DeepSpeed engine.
        scaler: grad scaler or ``None``; required when ``use_amp`` is set.
        info_dict: reads ``train_engine``, ``accum_grad``, ``use_amp`` and
            ``loss_dict``; ``loss_dict`` is updated in place.

    Returns:
        ``info_dict``.
    """
    engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    if info_dict.get('use_amp', False):
        assert scaler is not None
    loss = info_dict['loss_dict']['loss']

    if engine != "deepspeed":
        assert engine in ["torch_ddp", "torch_fsdp"]
        scaled_loss = loss / accum_grad
        # fp16 (amp / fsdp) goes through the scaler; fp32 and fsdp bf16
        # back-propagate directly.
        to_backprop = (scaled_loss
                       if scaler is None else scaler.scale(scaled_loss))
        to_backprop.backward()
    else:
        # NOTE(xcsong): `model.backward(loss)` is equivalent to
        #               `scale_loss_wrt_accum_grad + loss.backward()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        scaled_loss = model.backward(loss)

    losses = info_dict['loss_dict']
    losses['loss'] = scaled_loss
    for name, value in losses.items():
        if value is not None:
            losses[name] = tensor_to_scalar(value)

    return info_dict


def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
    """Apply a parameter/learning-rate update and record lrs and grad norm.

    * ``deepspeed``: ``model.step()`` is called on every batch. The engine
      handles clipping, stepping, zeroing and lr scheduling itself. The
      engine's accumulation-boundary flag is stored in
      ``info_dict["is_gradient_accumulation_boundary"]``, and its global
      grad norm is recorded (0.0 if it reports None).
    * ``torch_ddp`` / ``torch_fsdp``: only on every ``accum_grad``-th batch,
      gradients are clipped to ``grad_clip`` and the optimizer is stepped.
      With a scaler, stepping goes through ``scaler.step`` /
      ``scaler.update``. Without one, the optimizer steps only when the
      grad norm is finite. Gradients are then zeroed and the scheduler is
      stepped.

    Args:
        model: model/engine being trained; for fsdp it must provide
            ``clip_grad_norm_``.
        optimizer: torch optimizer (unused directly for deepspeed except to
            read ``param_groups``).
        scheduler: lr scheduler stepped after each ddp/fsdp update.
        scaler: gradient scaler or None. It must be non-None when
            ``info_dict['use_amp']`` is truthy.
        info_dict: requires ``batch_idx``; optionally ``train_engine``
            (default "torch_ddp"), ``accum_grad`` (default 1), ``use_amp``
            (default False) and ``grad_clip`` (default 50.0).

    Returns:
        ``info_dict`` with ``lrs`` (lr of each param group) and ``grad_norm``
        set. ``grad_norm`` stays 0.0 on ddp/fsdp batches without an update.
    """
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    use_amp = info_dict.get('use_amp', False)
    clip = info_dict.get('grad_clip', 50.0)
    batch_idx = info_dict["batch_idx"]
    if use_amp:
        assert scaler is not None

    grad_norm = 0.0
    if train_engine == "deepspeed":
        # NOTE(xcsong): The step() function in DeepSpeed engine updates the
        #   model parameters as well as the learning rate.
        #   Zeroing the gradients is handled automatically by
        #   DeepSpeed after the weights have been updated using a mini-batch.
        #   DeepSpeed also performs gradient averaging automatically at the
        #   gradient accumulation boundaries and addresses clip_grad_norm internally.
        #   `ds_model.step() =  clip_grad_norm_() + optimizer.step()
        #                       + optimizer.zero_grad() + scheduler.step()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        info_dict["is_gradient_accumulation_boundary"] = \
            model.is_gradient_accumulation_boundary()
        model.step()
        grad_norm = model.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = 0.0
    elif (batch_idx + 1) % accum_grad == 0:
        # Use mixed precision training
        # fp16 (ddp fsdp)
        if scaler is not None:
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:
                # fsdp
                grad_norm = model.clip_grad_norm_(clip)
            # Must invoke scaler.update() if unscale_() is used in
            # the iteration to avoid the following error:
            #   RuntimeError: unscale_() has already been called
            #   on this optimizer since the last update().
            # We don't check grad here since that if the gradient
            # has inf/nan values, scaler.step will skip
            # optimizer.step().
            scaler.step(optimizer)
            scaler.update()
        else:
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:
                grad_norm = model.clip_grad_norm_(clip)
            # Without a scaler nothing else skips a non-finite step.
            if torch.isfinite(grad_norm):
                optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups]
    info_dict["grad_norm"] = tensor_to_scalar(grad_norm)

    return info_dict


def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
    """Log per-step metrics to TensorBoard and to the shell.

    Behaviour depends on ``info_dict["tag"]``:

    * ``"TRAIN"``: on rank 0 with a writer, at each update boundary, write
      the loss, grad norm, other non-None losses and every learning rate
      under ``train/``. The boundary is deepspeed's
      ``is_gradient_accumulation_boundary``, or every ``accum_grad``-th batch
      for torch_ddp/torch_fsdp. The loss is multiplied by ``accum_grad``
      (presumably undoing the division applied before backward).
    * tags containing ``"step_"``: on rank 0 with a writer, write every
      non-None loss under ``cv/``, emit an INFO summary line and return.
      On other ranks, or with no writer, fall through to the shell log.
    * every ``log_interval``-th batch: build a one-line summary and emit it
      at DEBUG level.

    Args:
        writer: TensorBoard writer, or None to skip TensorBoard output.
        info_dict: requires ``tag``, ``step``, ``batch_idx`` and
            ``loss_dict``. ``grad_norm`` is also required whenever TRAIN
            output is produced.
        timer: optional StepTimer used to report steps/sec.
    """
    tag = info_dict["tag"]
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    epoch = info_dict.get('epoch', 0)
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
    log_interval = info_dict.get('log_interval', 10)
    lrs = info_dict.get("lrs", [0.0])
    is_gradient_accumulation_boundary = info_dict.get(
        "is_gradient_accumulation_boundary", False)

    rank = int(os.environ.get('RANK', 0))
    # TRAIN Tensorboard
    if tag == "TRAIN" and rank == 0 and writer is not None:
        if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
            ) or (train_engine in ["torch_ddp", "torch_fsdp"] and
                  (batch_idx + 1) % accum_grad == 0):
            writer.add_scalar('train/train_loss',
                              tensor_to_scalar(loss_dict['loss']) * accum_grad,
                              step)
            writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step)
            for name, value in loss_dict.items():
                if name != 'loss' and value is not None:
                    writer.add_scalar('train/{}'.format(name),
                                      tensor_to_scalar(value), step)
            # lr
            for i, lr in enumerate(lrs):
                writer.add_scalar('train/lr_{}'.format(i), lr, step)
    # CV Tensorboard
    elif "step_" in tag and rank == 0 and writer is not None:
        for name, value in loss_dict.items():
            # Skip absent losses, as the TRAIN branch does; a None scalar
            # cannot be written to TensorBoard.
            if value is not None:
                writer.add_scalar('cv/{}'.format(name),
                                  tensor_to_scalar(value), step)
        logging.info(
            'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
                epoch, step + 1, lrs_to_str(lrs),
                tensor_to_scalar(loss_dict["loss"]), rank,
                tensor_to_scalar(loss_dict["acc"])))
        return

    # TRAIN & CV, Shell log (stdout)
    if (batch_idx + 1) % log_interval == 0:
        log_str = '{} | '.format(tag)
        if timer is not None:
            # During CV the timer is advanced by a separate cv_step counter.
            timer_step = step
            if info_dict.get("cv_step", None) is not None:
                timer_step = info_dict['cv_step']
            steps_per_second = timer.steps_per_second(timer_step)
            log_str += 'steps/sec {:.3f}| '.format(steps_per_second)
        log_str += 'Batch {}/{} loss {:.6f} '.format(
            epoch, batch_idx + 1 if 'save_interval' not in info_dict else
            (step + 1) * accum_grad,
            tensor_to_scalar(loss_dict['loss']) * accum_grad)
        for name, value in loss_dict.items():
            if name != 'loss' and value is not None:
                log_str += '{} {:.6f} '.format(name, tensor_to_scalar(value))
        if tag == "TRAIN":
            log_str += 'lr {} grad_norm {:.6f} rank {}'.format(
                lrs_to_str(lrs), info_dict['grad_norm'], rank)
        logging.debug(log_str)


def log_per_epoch(writer, info_dict):
    """Log end-of-epoch CV results to the shell and to TensorBoard.

    Every rank emits an INFO line with epoch, step, learning rates, CV loss
    and accuracy. Rank 0 additionally writes each learning rate under
    ``epoch/lr_<i>`` and each non-None loss under ``epoch/<name>``, indexed
    by epoch, provided a writer is available.

    Args:
        writer: TensorBoard writer, or None to skip TensorBoard output.
        info_dict: requires ``epoch``, ``loss_dict`` (with ``loss`` and
            ``acc`` entries), ``lrs`` and ``step``.
    """
    epoch = info_dict["epoch"]
    loss_dict = info_dict["loss_dict"]
    lrs = info_dict['lrs']
    rank = int(os.environ.get('RANK', 0))
    step = info_dict["step"]
    logging.info(
        'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
            epoch, step, lrs_to_str(lrs), tensor_to_scalar(loss_dict["loss"]),
            rank, tensor_to_scalar(loss_dict["acc"])))

    # Only rank 0 writes TensorBoard scalars. Tolerate a missing writer
    # instead of failing on ``None.add_scalar``.
    if rank == 0 and writer is not None:
        for i, lr in enumerate(lrs):
            writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch)
        for name, value in loss_dict.items():
            # A None scalar cannot be written to TensorBoard.
            if value is not None:
                writer.add_scalar('epoch/{}'.format(name),
                                  tensor_to_scalar(value), epoch)


def freeze_modules(model, args):
    """Disable gradients for parameters whose names contain a freeze pattern.

    A parameter from ``model.named_parameters()`` is frozen
    (``requires_grad = False``) when any entry of ``args.freeze_modules``
    occurs as a substring of its qualified name. Each frozen parameter is
    logged once at DEBUG level, even when several patterns match it.

    Args:
        model: a ``torch.nn.Module``.
        args: object with a ``freeze_modules`` iterable of name substrings.
    """
    # Materialize once so a one-shot iterable is not exhausted after the
    # first parameter.
    patterns = list(args.freeze_modules)
    for name, param in model.named_parameters():
        if any(pattern in name for pattern in patterns):
            param.requires_grad = False
            logging.debug("{} module is freezed".format(name))


def reinit_lora(model, args, configs, tokenizer, seed=777):
    """Re-initialize every LoRA layer of ``model`` and save the result.

    Reads the LoRA init settings from the YAML file ``args.lora_init_yaml``
    and builds a training dataloader with batch size ``init_batch_size``.
    When ``init_config.mode`` is ``"gradient"``, gradients are first
    estimated over ``init_iters`` iterations. Every ``LoRALayer`` is then
    passed to ``reinit_lora_modules``. Finally the model is saved to
    ``<args.model_dir>/lora_init.pt``.

    Args:
        model: model containing LoRA layers (modified in place).
        args: parsed CLI args (data paths, dataloader options, model_dir,
            lora_init_yaml).
        configs: training config dict; ``dataset_conf`` is deep-copied
            before being modified.
        tokenizer: tokenizer handed to ``init_dataset``.
        seed: seed of the dataloader's ``torch.Generator``.
    """
    from tqdm import tqdm
    from wenet.finetune.lora.utils import estimate_gradient, reinit_lora_modules
    from wenet.finetune.lora.layers import LoRALayer
    from types import SimpleNamespace

    logging.info("reinit lora modules.")
    with open(args.lora_init_yaml, 'r') as file:
        lora_yaml = yaml.safe_load(file)

    generator = torch.Generator()
    generator.manual_seed(seed)
    dataset_conf = copy.deepcopy(configs['dataset_conf'])
    dataset_conf['batch_conf']['batch_size'] = lora_yaml['init_batch_size']
    dataset_type = configs.get('dataset', 'asr')
    dataset = init_dataset(dataset_type, args.data_type, args.train_data,
                           tokenizer, dataset_conf, True)
    # DataLoader raises ValueError for persistent_workers=True or an explicit
    # prefetch_factor when num_workers == 0, so pass them only when worker
    # processes are actually used.
    worker_kwargs = {}
    if args.num_workers > 0:
        worker_kwargs['persistent_workers'] = True
        worker_kwargs['prefetch_factor'] = args.prefetch
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            pin_memory=args.pin_memory,
                            num_workers=args.num_workers,
                            generator=generator,
                            **worker_kwargs)
    additional_kwargs = {}
    if lora_yaml["init_config"]["mode"] == "gradient":
        named_grads = estimate_gradient(model, dataloader,
                                        lora_yaml['init_iters'])
        additional_kwargs["named_grads"] = named_grads
    init_config = SimpleNamespace(**lora_yaml["init_config"])
    # Enumerate modules once; the list also gives tqdm its total.
    named_modules = list(model.named_modules())
    for name, module in tqdm(named_modules,
                             desc="Reinitializing Lora",
                             total=len(named_modules)):
        if isinstance(module, LoRALayer):
            reinit_lora_modules(name, module, init_config,
                                **additional_kwargs)
    # lora_init_model needs to be saved, w0 = w0 - A0 * B0
    save_checkpoint(model, os.path.join(args.model_dir, "lora_init.pt"),
                    infos={"tag": "lora_init", **configs})
